package com.baomidou.springboot.utils;

import com.alibaba.fastjson.JSONObject;
import com.baomidou.springboot.config.ConfigurationManager;
import com.baomidou.springboot.constant.Constants;
import com.baomidou.springboot.exception.ParameterException;
import com.baomidou.springboot.spark.AnalysisMain;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.execution.columnar.LONG;
import org.apache.spark.sql.hive.HiveContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Spark utility helpers: master configuration, {@link SQLContext} selection,
 * loading local test data into temp tables, and querying the
 * {@code user_visit_action} table by date range.
 *
 * <p>Thread-safety: stateless (static methods only).
 *
 * @author Administrator
 */
public class SparkUtils {

	protected static final Logger logger = LoggerFactory.getLogger(SparkUtils.class);

	/** Utility class; not meant to be instantiated. */
	private SparkUtils() {
	}

	/**
	 * Sets the SparkConf master to {@code "local"} when the {@code spark.local}
	 * flag is enabled; otherwise leaves the master untouched so it can be
	 * supplied externally (e.g. by spark-submit).
	 *
	 * @param conf the SparkConf to configure
	 */
	public static void setMaster(SparkConf conf) {
		boolean local = ConfigurationManager.getBoolean(Constants.SPARK_LOCAL);
		if (local) {
			conf.setMaster("local");
		}
	}

	/**
	 * Returns a SQLContext appropriate for the runtime mode: a plain
	 * {@link SQLContext} when {@code spark.local} is true, otherwise a
	 * {@link HiveContext} for cluster execution against Hive tables.
	 *
	 * @param sc the underlying SparkContext
	 * @return SQLContext (local mode) or HiveContext (cluster mode)
	 */
	public static SQLContext getSQLContext(SparkContext sc) {
		boolean local = ConfigurationManager.getBoolean(Constants.SPARK_LOCAL);
		if (local) {
			return new SQLContext(sc);
		} else {
			return new HiveContext(sc);
		}
	}

	/**
	 * Loads the local test data files and registers them as the temp tables
	 * {@code user_visit_action} and {@code user_info}.
	 *
	 * <p>Input files are space-separated; see the schemas built in the helpers
	 * below for the expected column order. The literal string {@code "null"}
	 * in the click-id columns is mapped to a SQL NULL.
	 *
	 * @param sc         the JavaSparkContext used to read the text files
	 * @param sqlContext the SQLContext the temp tables are registered on
	 */
	public static void loadLocalTestDataToTmpTable(JavaSparkContext sc, SQLContext sqlContext) {
		// NOTE(review): the constant names carry a historical typo ("LCOAL");
		// they are declared in Constants and referenced as-is for compatibility.
		String sessionPath = ConfigurationManager.getProperty(Constants.SPARK_LCOAL_SESSION_DATA_PATH);
		String userPath = ConfigurationManager.getProperty(Constants.SPARK_LCOAL_USER_DATA_PATH);
		registerUserVisitActionTable(sc.textFile(sessionPath), sqlContext);
		registerUserInfoTable(sc.textFile(userPath), sqlContext);
	}

	/** Parses session-action lines and registers them as temp table "user_visit_action". */
	private static void registerUserVisitActionTable(JavaRDD<String> lines, SQLContext sqlContext) {
		// Build a Row RDD on top of the raw text RDD.
		JavaRDD<Row> rowsRDD = lines.map(new Function<String, Row>() {
			private static final long serialVersionUID = 1L;
			public Row call(String line) throws Exception {
				String[] split = line.split(" ");
				// Columns 7/8 (click ids) may hold the literal "null" -> SQL NULL.
				Long clickCategoryId = "null".equals(split[7]) ? null : Long.valueOf(split[7]);
				Long clickProductId = "null".equals(split[8]) ? null : Long.valueOf(split[8]);
				return RowFactory.create(
						split[0],                  // session_id
						Long.valueOf(split[1]),    // user_id
						split[2],                  // date
						Long.valueOf(split[3]),    // page_id
						split[4] + " " + split[5], // action_time (date + time tokens)
						split[6],                  // search_keyword
						clickCategoryId,
						clickProductId,
						split[9],                  // order_category_ids
						split[10],                 // order_product_ids
						split[11],                 // pay_category_ids
						split[12]                  // pay_product_ids
				);
			}
		});

		StructType schema = DataTypes.createStructType(Arrays.asList(
				DataTypes.createStructField("session_id", DataTypes.StringType, true),
				DataTypes.createStructField("user_id", DataTypes.LongType, true),
				DataTypes.createStructField("date", DataTypes.StringType, true),
				DataTypes.createStructField("page_id", DataTypes.LongType, true),
				DataTypes.createStructField("action_time", DataTypes.StringType, true),
				DataTypes.createStructField("search_keyword", DataTypes.StringType, true),
				DataTypes.createStructField("click_category_id", DataTypes.LongType, true),
				DataTypes.createStructField("click_product_id", DataTypes.LongType, true),
				DataTypes.createStructField("order_category_ids", DataTypes.StringType, true),
				DataTypes.createStructField("order_product_ids", DataTypes.StringType, true),
				DataTypes.createStructField("pay_category_ids", DataTypes.StringType, true),
				DataTypes.createStructField("pay_product_ids", DataTypes.StringType, true)));

		DataFrame df = sqlContext.createDataFrame(rowsRDD, schema);
		df.registerTempTable("user_visit_action");
		logger.info("user_visit_action表注册成功！");
	}

	/** Parses user lines and registers them as temp table "user_info". */
	private static void registerUserInfoTable(JavaRDD<String> lines, SQLContext sqlContext) {
		JavaRDD<Row> rowsRDD = lines.map(new Function<String, Row>() {
			private static final long serialVersionUID = 1L;
			public Row call(String line) throws Exception {
				String[] split = line.split(" ");
				return RowFactory.create(
						Long.valueOf(split[0]),    // user_id
						split[1],                  // username
						split[2],                  // name
						Integer.valueOf(split[3]), // age
						split[4],                  // professional
						split[5],                  // city
						split[6]                   // sex
				);
			}
		});

		StructType schema = DataTypes.createStructType(Arrays.asList(
				DataTypes.createStructField("user_id", DataTypes.LongType, true),
				DataTypes.createStructField("username", DataTypes.StringType, true),
				DataTypes.createStructField("name", DataTypes.StringType, true),
				DataTypes.createStructField("age", DataTypes.IntegerType, true),
				DataTypes.createStructField("professional", DataTypes.StringType, true),
				DataTypes.createStructField("city", DataTypes.StringType, true),
				DataTypes.createStructField("sex", DataTypes.StringType, true)));

		DataFrame df = sqlContext.createDataFrame(rowsRDD, schema);
		df.registerTempTable("user_info");
		logger.info("user_info表注册成功！");
	}

	/**
	 * Returns the user-action rows that fall within the task's date range.
	 *
	 * <p>Range handling:
	 * <ul>
	 *   <li>start and end date both present and non-empty: closed range [start, end]</li>
	 *   <li>only start date present (end null or empty): [start, today]</li>
	 *   <li>otherwise: full table scan</li>
	 * </ul>
	 *
	 * @param sqlContext context the {@code user_visit_action} temp table is registered on
	 * @param taskParam  task parameters carrying startDate/endDate (may be absent)
	 * @return RDD of matching rows
	 */
	public static JavaRDD<Row> getActionRDDByDateRange(
			SQLContext sqlContext, JSONObject taskParam) {
		String startDate = ParamUtils.getParamWithSplit(taskParam, Constants.PARAM_START_DATE);
		String endDate = ParamUtils.getParamWithSplit(taskParam, Constants.PARAM_END_DATE);

		DataFrame actionDF;
		if (startDate != null && endDate != null && !endDate.isEmpty()) {
			actionDF = sqlContext.sql(buildDateRangeSql(startDate, endDate));
		} else if (startDate != null) {
			// BUGFIX: the old code evaluated endDate.isEmpty() in a branch that is
			// only reachable when endDate == null, throwing an NPE. A null or empty
			// end date now both default to today's date.
			String defaultEnd = DateUtils.getTodayDate().toString();
			actionDF = sqlContext.sql(buildDateRangeSql(startDate, defaultEnd));
		} else {
			actionDF = sqlContext.sql("select * from user_visit_action ");
		}

		/*
		 * Spark SQL decides the first stage's partition count itself; if the
		 * workload needs more parallelism, repartition the result here, e.g.
		 * actionDF.javaRDD().repartition(1000).
		 */
		return actionDF.javaRDD();
	}

	/**
	 * Builds the closed-range query over {@code user_visit_action}.
	 *
	 * <p>SECURITY NOTE: the dates are concatenated into the SQL text. They come
	 * from task parameters; if those can ever be user-supplied, validate the
	 * date format upstream before querying.
	 */
	private static String buildDateRangeSql(String startDate, String endDate) {
		return "select * "
				+ "from user_visit_action "
				+ "where date>='" + startDate + "' "
				+ "and date<='" + endDate + "'";
	}

}
